#include <algorithm>
#include <cmath>

#include "CSE.h"
#include "IR.h"
#include "IROperator.h"
#include "Lerp.h"
#include "Simplify.h"
#include "Target.h"

namespace Halide {
namespace Internal {

// Lower a Lerp intrinsic into primitive Halide IR.
//
// Interpolates between zero_val (selected when the weight is 0 / 0.0f)
// and one_val (selected when the weight is at its unsigned type's max,
// or 1.0f for float weights). zero_val and one_val must have the same
// type; the weight must be unsigned-integer or float. The blend is
// computed at the precision of zero_val/one_val (in the equivalent
// unsigned type for signed inputs) and the result is cast to
// final_type if that differs.
Expr lower_lerp(Type final_type, Expr zero_val, Expr one_val, const Expr &weight, const Target &target) {

    Expr result;

    internal_assert(zero_val.type() == one_val.type());
    internal_assert(weight.type().is_uint() || weight.type().is_float())
        << "Bad weight type: " << weight.type() << "\n";

    Type result_type = zero_val.type();

    // bias_value shifts signed inputs into unsigned space; it stays
    // zero for unsigned/float/bool inputs.
    Expr bias_value = make_zero(result_type);
    Type computation_type = result_type;

    if (zero_val.type().is_int()) {
        computation_type = UInt(zero_val.type().bits(), zero_val.type().lanes());
        // We must take care to do the addition and subtraction of the
        // bias while in the unsigned computation type, where overflow
        // is well-defined.
        bias_value = cast(computation_type, result_type.min());
    }

    // For signed integer types, just convert everything to unsigned
    // and then back at the end to ensure proper rounding, etc.
    // There is likely a better way to handle this.
    // (Subtracting min() with unsigned wraparound maps the signed
    // range [min, max] onto [0, 2^bits - 1], preserving order.)
    if (result_type != computation_type) {
        zero_val = Cast::make(computation_type, zero_val) - bias_value;
        one_val = Cast::make(computation_type, one_val) - bias_value;
    }

    if (result_type.is_bool()) {
        // Boolean lerp is a threshold at the midpoint of the weight's
        // range: weights above half select one_val.
        Expr half_weight;
        if (weight.type().is_float()) {
            half_weight = 0.5f;
        } else {
            half_weight = weight.type().max() / 2;
        }

        result = select(weight > half_weight, one_val, zero_val);
    } else {
        // typed_weight is the weight converted into the computation
        // type; inverse_typed_weight represents (1 - weight) in the
        // same representation.
        Expr typed_weight;
        Expr inverse_typed_weight;

        if (weight.type().is_float()) {
            typed_weight = weight;
            if (computation_type.is_uint()) {
                // TODO: Verify this reduces to efficient code or
                // figure out a better way to express a multiply
                // of unsigned 2^32-1 by a double promoted weight
                if (computation_type.bits() == 32) {
                    // Scale the float weight to the full 32-bit range.
                    // 65535.0 * 65537.0 == 2^32 - 1; the product is
                    // formed in double so it is exact (2^32 - 1 is not
                    // representable as a float).
                    typed_weight =
                        Cast::make(computation_type,
                                   cast<double>(Expr(65535.0f)) * cast<double>(Expr(65537.0f)) *
                                       Cast::make(Float(64, typed_weight.type().lanes()), typed_weight));
                } else {
                    // 8/16-bit max() is exactly representable in
                    // float, so scale directly.
                    typed_weight =
                        Cast::make(computation_type,
                                   computation_type.max() * typed_weight);
                }
                inverse_typed_weight = computation_type.max() - typed_weight;
            } else {
                inverse_typed_weight = make_one(computation_type) - typed_weight;
            }

        } else {
            if (computation_type.is_float()) {
                // Integer weight with float computation: normalize the
                // weight into [0, 1] by dividing by its maximum value
                // (2^weight_bits - 1, computed exactly via ldexp).
                int weight_bits = weight.type().bits();
                Expr denom = make_const(computation_type, (ldexp(1.0, weight_bits) - 1));
                typed_weight = Cast::make(computation_type, weight) / denom;
                inverse_typed_weight = make_one(computation_type) - typed_weight;
            } else {
                // This code rescales integer weights to the right number of bits.
                // It takes advantage of (2^n - 1) == (2^(n/2) - 1)(2^(n/2) + 1)
                // e.g. 65535 = 255 * 257. (Ditto for the 32-bit equivalent.)
                // To rescale a weight of m bits to be n bits, we need to do:
                //     scaled_weight = (weight / (2^m - 1)) * (2^n - 1)
                // With power-of-two bit widths m and n, the ratio of the two
                // all-ones constants factors into a product of terms
                //     (2^k + 1) * (2^(2k) + 1) * (2^(4k) + 1) * ...
                // where k = min(m, n). E.g. for m = 8, n = 32:
                // (2^32 - 1)/(2^8 - 1) = 257 * 65537.
                // The loop below computes that scaling constant and either
                // multiplies (when widening) or divides (when narrowing)
                // by the constant, and relies on lowering and llvm to
                // generate efficient code for the operation.
                int bit_size_difference = weight.type().bits() - computation_type.bits();
                if (bit_size_difference == 0) {
                    // Already the right width; use as-is.
                    typed_weight = weight;
                } else {
                    typed_weight = Cast::make(computation_type, weight);

                    int bits_left = ::abs(bit_size_difference);
                    int shift_amount = std::min(computation_type.bits(), weight.type().bits());
                    uint64_t scaling_factor = 1;
                    while (bits_left != 0) {
                        internal_assert(bits_left > 0);
                        // Fold in the next (2^shift_amount + 1) factor.
                        scaling_factor = scaling_factor + (scaling_factor << shift_amount);
                        bits_left -= shift_amount;
                        shift_amount *= 2;
                    }
                    if (bit_size_difference < 0) {
                        // Weight is narrower than the computation type:
                        // widen first, then multiply up to full range.
                        typed_weight =
                            Cast::make(computation_type, weight) *
                            cast(computation_type, (int32_t)scaling_factor);
                    } else {
                        // Weight is wider: divide down in the wide type
                        // before narrowing.
                        typed_weight =
                            Cast::make(computation_type,
                                       weight / cast(weight.type(), (int32_t)scaling_factor));
                    }
                }
                inverse_typed_weight =
                    Cast::make(computation_type,
                               computation_type.max() - typed_weight);
            }
        }

        if (computation_type.is_float()) {
            result = (zero_val * inverse_typed_weight +
                      one_val * typed_weight);
        } else {
            int32_t bits = computation_type.bits();
            switch (bits) {
            case 1:
                result = select(typed_weight, one_val, zero_val);
                break;
            case 8:
            case 16:
            case 32: {
                // Blend in double-width arithmetic so the per-term
                // products cannot overflow.
                Expr prod_sum = (widening_mul(zero_val, inverse_typed_weight) +
                                 widening_mul(one_val, typed_weight));
                // Now we need to do a rounding divide and narrow. For
                // 8-bit, this rounding divide looks like (x + 127) /
                // 255. On most platforms we can compute this as
                // ((x + 128) / 256 + x + 128) / 256. Note that
                // overflow is impossible here because the most our
                // prod_sum can be is 255^2.
                if (target.arch == Target::X86) {
                    // On x86 we have no rounding shifts but we do
                    // have a multiply-keep-high-half. So it's
                    // actually one instruction cheaper to do the
                    // division directly.
                    // cast(UInt(bits), -1) wraps around to 2^bits - 1.
                    Expr divisor = cast(UInt(bits), -1);
                    result = (prod_sum + divisor / 2) / divisor;
                } else {
                    result = rounding_shift_right(rounding_shift_right(prod_sum, bits) + prod_sum, bits);
                }
                break;
            }
            case 64:
                // TODO: 64-bit lerp is not supported as current approach
                // requires double-width multiply.
                // There is an informative error message in IROperator.h.
                internal_error << "Can't do a 64-bit lerp.\n";
                break;
            default:
                break;
            }

            if (weight.type().is_float()) {
                // Insert an explicit cast to the computation type, even if
                // we're going to widen, because out-of-range floats can produce
                // out-of-range outputs.
                result = Cast::make(computation_type, result);
            }
        }

        if (!is_const_zero(bias_value)) {
            // Undo the signed->unsigned bias and return to the signed
            // result type (wraparound restores the original range).
            result = Cast::make(result_type, result + bias_value);
        }
    }

    if (result.type() != final_type) {
        result = Cast::make(final_type, result);
    }

    return simplify(common_subexpression_elimination(result));
}

}  // namespace Internal
}  // namespace Halide
